{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9B8F749E4C954AD8AFE6B96DD1F8B830",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "# 批量归一化（BatchNormalization）\n",
    "$\\color{#FF0000}{BN层的作用机制也许是通过平滑隐藏层输入的分布，帮助随机梯度下降的进行，缓解随机梯度下降权重更新对后续层的负面影响。对sigmoid和tanh而言，放非线性激活之前，也许顺便还能缓解sigmoid/tanh的梯度衰减问题，而对ReLU而言，这个平滑作用经ReLU“扭曲”之后也许有所衰弱}$\n",
    "#### 对输入的标准化（浅层模型）\n",
    "处理后的任意一个特征在数据集中所有样本上的均值为0、标准差为1。  \n",
    "标准化处理输入数据使各个特征的分布相近\n",
    "#### 批量归一化（深度模型）\n",
    "利用小批量上的均值和标准差，不断调整神经网络中间输出，从而使整个神经网络在各层的中间输出的数值更稳定。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3E15ABE688814DCB811DF2D4D3A0BBDB",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### 1.对全连接层做批量归一化\n",
    "位置：全连接层中的仿射变换和激活函数之间。  \n",
    "**全连接：**  \n",
    "$$\n",
    "\\boldsymbol{x} = \\boldsymbol{W\\boldsymbol{u} + \\boldsymbol{b}} \\\\\n",
    " output =\\phi(\\boldsymbol{x})\n",
    " $$   \n",
    "\n",
    "\n",
    "**批量归一化：**\n",
    "$$ \n",
    "output=\\phi(\\text{BN}(\\boldsymbol{x}))$$\n",
    "\n",
    "\n",
    "$$\n",
    "\\boldsymbol{y}^{(i)} = \\text{BN}(\\boldsymbol{x}^{(i)})\n",
    "$$\n",
    "\n",
    "\n",
    "\n",
    "$$\n",
    "\\boldsymbol{\\mu}_\\mathcal{B} \\leftarrow \\frac{1}{m}\\sum_{i = 1}^{m} \\boldsymbol{x}^{(i)},\n",
    "$$ \n",
    "i代表每一个样本，上面这个是在bachsize上做平均\n",
    "$$\n",
    "\\boldsymbol{\\sigma}_\\mathcal{B}^2 \\leftarrow \\frac{1}{m} \\sum_{i=1}^{m}(\\boldsymbol{x}^{(i)} - \\boldsymbol{\\mu}_\\mathcal{B})^2,\n",
    "$$\n",
    "\n",
    "\n",
    "$$\n",
    "\\hat{\\boldsymbol{x}}^{(i)} \\leftarrow \\frac{\\boldsymbol{x}^{(i)} - \\boldsymbol{\\mu}_\\mathcal{B}}{\\sqrt{\\boldsymbol{\\sigma}_\\mathcal{B}^2 + \\epsilon}},\n",
    "$$\n",
    "\n",
    "这⾥ϵ > 0是个很小的常数，保证分母大于0\n",
    "\n",
    "\n",
    "$$\n",
    "{\\boldsymbol{y}}^{(i)} \\leftarrow \\boldsymbol{\\gamma} \\odot\n",
    "\\hat{\\boldsymbol{x}}^{(i)} + \\boldsymbol{\\beta}.\n",
    "$$\n",
    "\n",
    "\n",
    "引入可学习参数：拉伸参数γ和偏移参数β。若$\\boldsymbol{\\gamma} = \\sqrt{\\boldsymbol{\\sigma}_\\mathcal{B}^2 + \\epsilon}$和$\\boldsymbol{\\beta} = \\boldsymbol{\\mu}_\\mathcal{B}$，批量归一化无效。这两个参数是为了保证归一化后的信息可以还原到以前的信息.打个比方 (如果信息太重要,归一化会损失很多信息.那么我们可以通过这两个参数还原归一化后信息.\n",
    "\n",
    "### 2.对卷积层做批量归⼀化\n",
    "位置：卷积计算之后、应⽤激活函数之前。  \n",
    "如果卷积计算输出多个通道，我们需要对这些通道的输出分别做批量归一化，且每个通道都拥有独立的拉伸和偏移参数。\n",
    "计算：对单通道，batchsize=m,卷积计算输出=pxq\n",
    "对该通道中m×p×q个元素同时做批量归一化,使用相同的均值和方差。\n",
    "\n",
    "### 3.预测时的批量归⼀化\n",
    "训练：以batch为单位,对每个batch计算均值和方差。  \n",
    "预测：用移动平均估算整个训练数据集的样本均值和方差。\n",
    "$\\color{#FF0000}{利用小批量上的均值和标准差，不断调整神经网络中间输出，从而使整个神经网络在各层的中间输出的数值更稳定。}$\n",
    "\n",
    "### 从零实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "39F35F2C0FE148B498B49A9E7D5A96C1",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "#目前GPU算力资源预计17日上线，在此之前本代码只能使用CPU运行。\n",
    "#考虑到本代码中的模型过大，CPU训练较慢，\n",
    "#我们还将代码上传了一份到 https://www.kaggle.com/boyuai/boyu-d2l-deepcnn\n",
    "#如希望提前使用gpu运行请至kaggle。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "96E72C7A9429485180F9B63F7FB07CB9",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import sys\n",
    "sys.path.append(\"/home/kesci/input/\") \n",
    "import d2lzh1981 as d2l\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
    "    # 判断当前模式是训练模式还是预测模式\n",
    "    if not is_training:\n",
    "        # 如果是在预测模式下，直接使用传入的移动平均所得的均值和方差\n",
    "        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    else:\n",
    "        assert len(X.shape) in (2, 4)#判断是全连接层还是卷积层\n",
    "        if len(X.shape) == 2:\n",
    "            # 使用全连接层的情况，计算特征维上的均值和方差\n",
    "            mean = X.mean(dim=0)\n",
    "            #dim=0说明时在batchsize上做的平均\n",
    "            var = ((X - mean) ** 2).mean(dim=0)\n",
    "        else:\n",
    "            # 使用二维卷积层的情况，计算通道维上（axis=1）的均值和方差。这里我们需要保持\n",
    "            # X的形状以便后面可以做广播运算\n",
    "            #老师讲下面这行代码是先对第三维求均值，再将该均值在第二维上求均值，再在第0维上求均值的原因是什么呢？？？？？\n",
    "            #size是否有参数（即1*n前面的这个1）取决于 keepdim=True\n",
    "            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "        # 训练模式下用当前的均值和方差做标准化\n",
    "        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
    "        # 更新移动平均的均值和方差\n",
    "        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "        moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
    "    Y = gamma * X_hat + beta  # 拉伸和偏移？？？？？？为什么这么做？？？？\n",
    "    return Y, moving_mean, moving_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "04A65ABC9F20404E83F166F43226873D",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BatchNorm(nn.Module):\n",
    "    def __init__(self, num_features, num_dims):\n",
    "        super(BatchNorm, self).__init__()\n",
    "        if num_dims == 2:\n",
    "            shape = (1, num_features) #全连接层输出神经元\n",
    "        else:\n",
    "            shape = (1, num_features, 1, 1)  #通道数\n",
    "        # 参与求梯度和迭代的拉伸和偏移参数，分别初始化成0和1\n",
    "        self.gamma = nn.Parameter(torch.ones(shape))\n",
    "        self.beta = nn.Parameter(torch.zeros(shape))\n",
    "        # 不参与求梯度和迭代的变量，全在内存上初始化成0\n",
    "        self.moving_mean = torch.zeros(shape)\n",
    "        self.moving_var = torch.zeros(shape)\n",
    "\n",
    "    def forward(self, X):\n",
    "        # 如果X不在内存上，将moving_mean和moving_var复制到X所在显存上\n",
    "        if self.moving_mean.device != X.device:\n",
    "            self.moving_mean = self.moving_mean.to(X.device)\n",
    "            self.moving_var = self.moving_var.to(X.device)\n",
    "        # 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成false\n",
    "        Y, self.moving_mean, self.moving_var = batch_norm(self.training, \n",
    "            X, self.gamma, self.beta, self.moving_mean,\n",
    "            self.moving_var, eps=1e-5, momentum=0.9)\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4FE611706B7F4C1B8B57182BC1CE450B",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### 基于LeNet的应用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "1B18E8F1B4C843E6852A0B50CFFD9C53",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (1): BatchNorm()\n",
      "  (2): Sigmoid()\n",
      "  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (4): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (5): BatchNorm()\n",
      "  (6): Sigmoid()\n",
      "  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (8): FlattenLayer()\n",
      "  (9): Linear(in_features=256, out_features=120, bias=True)\n",
      "  (10): BatchNorm()\n",
      "  (11): Sigmoid()\n",
      "  (12): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (13): BatchNorm()\n",
      "  (14): Sigmoid()\n",
      "  (15): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size\n",
    "            #下面6代表通道数，num_dim=4代表卷积\n",
    "            BatchNorm(6, num_dims=4),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2), # kernel_size, stride\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            BatchNorm(16, num_dims=4),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            d2l.FlattenLayer(),\n",
    "            nn.Linear(16*4*4, 120),\n",
    "            BatchNorm(120, num_dims=2),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(120, 84),\n",
    "            BatchNorm(84, num_dims=2),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(84, 10)\n",
    "        )\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "F11883676ED64A2DB05BC3480AF28BFF",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "#batch_size = 256  \n",
    "##cpu要调小batchsize\n",
    "batch_size=16\n",
    "\n",
    "def load_data_fashion_mnist(batch_size, resize=None, root='/home/kesci/input/FashionMNIST2065'):\n",
    "    \"\"\"Download the fashion mnist dataset and then load into memory.\"\"\"\n",
    "    trans = []\n",
    "    if resize:\n",
    "        trans.append(torchvision.transforms.Resize(size=resize))\n",
    "    trans.append(torchvision.transforms.ToTensor())\n",
    "    \n",
    "    transform = torchvision.transforms.Compose(trans)\n",
    "    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)\n",
    "    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)\n",
    "\n",
    "    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)\n",
    "    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)\n",
    "\n",
    "    return train_iter, test_iter\n",
    "train_iter, test_iter = load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "9F9DC5C22F5942A48A8A4B656B33D543",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AC692570D35A492BB82A50E9BBBBB598",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### 简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EAC5ACDF548B4831992A653DF4FD4348",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size\n",
    "            #下面这个是内置函数，6代表输入通道数，一问你BatchNorm2d代表全连接层，所以不用输入4这个参数\n",
    "            nn.BatchNorm2d(6),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2), # kernel_size, stride\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            d2l.FlattenLayer(),\n",
    "            nn.Linear(16*4*4, 120),\n",
    "            #BatchNorm1d代表全连接层\n",
    "            nn.BatchNorm1d(120),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.BatchNorm1d(84),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(84, 10)\n",
    "        )\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "50F442253C51495181611BB088576C75",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "# 残差网络（ResNet）\n",
    "深度学习的问题：深度CNN网络达到一定深度后再一味地增加层数并不能带来进一步地分类性能提高，反而会招致网络收敛变得更慢，准确率也变得更差。\n",
    "### 残差块（Residual Block）\n",
    "恒等映射：  \n",
    "左边：f(x)=x                                                  \n",
    "右边：f(x)-x=0 （易于捕捉恒等映射的细微波动）\n",
    "$\\color{#FF0000}{上面的这个恒等映射的原因不清楚，为什么是这样子的呢？？}$\n",
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5l8lhnot4.png?imageView2/0/w/600/h/600)\n",
    "\n",
    "在残差块中，输⼊可通过跨层的数据线路更快 地向前传播。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "3F859601FD264D1A8D5E5E51B9D83D0D",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Residual(nn.Module):  # 本类已保存在d2lzh_pytorch包中方便以后使用\n",
    "    #可以设定输出通道数、是否使用额外的1x1卷积层来修改通道数以及卷积层的步幅。\n",
    "    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):\n",
    "        super(Residual, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
    "        if use_1x1conv:\n",
    "            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
    "        else:\n",
    "            self.conv3 = None\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "    def forward(self, X):\n",
    "        Y = F.relu(self.bn1(self.conv1(X)))\n",
    "        Y = self.bn2(self.conv2(Y))\n",
    "        if self.conv3:\n",
    "            X = self.conv3(X)\n",
    "        return F.relu(Y + X)\n",
    "    #上面这个forward的函数是怎么回事呢？？？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "2E2CF8846F5C43618B6DB004FA95BC0E",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 6, 6])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(3, 3)\n",
    "X = torch.rand((4, 3, 6, 6))\n",
    "blk(X).shape # torch.Size([4, 3, 6, 6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "C8F6DFFB58344B669C24A18721B8A66F",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 6, 3, 3])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(3, 6, use_1x1conv=True, stride=2)#输入通道数不一样\n",
    "blk(X).shape # torch.Size([4, 6, 3, 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CEFB948C07AE428D8F5F756CDE556CF9",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### ResNet模型\n",
    "卷积(64,7x7,3)  \n",
    "批量一体化  \n",
    "最大池化(3x3,2)  \n",
    "\n",
    "残差块x4 (通过步幅为2的残差块在每个模块之间减小高和宽)\n",
    "\n",
    "全局平均池化\n",
    "\n",
    "全连接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "6117AE144A6B4EF58B382D31F1C2F760",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "        nn.BatchNorm2d(64), \n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "D1E8052B1A8242F285BDD3E54418419B",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def resnet_block(in_channels, out_channels, num_residuals, first_block=False):\n",
    "    if first_block:\n",
    "        assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致\n",
    "    blk = []\n",
    "    for i in range(num_residuals):\n",
    "        if i == 0 and not first_block:\n",
    "            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))\n",
    "        else:\n",
    "            blk.append(Residual(out_channels, out_channels))\n",
    "    return nn.Sequential(*blk)\n",
    "\n",
    "net.add_module(\"resnet_block1\", resnet_block(64, 64, 2, first_block=True))\n",
    "net.add_module(\"resnet_block2\", resnet_block(64, 128, 2))\n",
    "net.add_module(\"resnet_block3\", resnet_block(128, 256, 2))\n",
    "net.add_module(\"resnet_block4\", resnet_block(256, 512, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "93D7FFC7581F438A82E92D85E3F6D24E",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)\n",
    "net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "EE1FD4E1510941AB98B3C8840DA47E44",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "1  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "2  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "3  output shape:\t torch.Size([1, 64, 56, 56])\n",
      "resnet_block1  output shape:\t torch.Size([1, 64, 56, 56])\n",
      "resnet_block2  output shape:\t torch.Size([1, 128, 28, 28])\n",
      "resnet_block3  output shape:\t torch.Size([1, 256, 14, 14])\n",
      "resnet_block4  output shape:\t torch.Size([1, 512, 7, 7])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 512, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((1, 1, 224, 224))#自己推演一下每层的大小。\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "7A68C8B9480C4DB8935436EC136068F1",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "610FDA7246E64E6AAA8669EEDEA33BB1",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "# 稠密连接网络（DenseNet）\n",
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5l8mi78yz.png?imageView2/0/w/600/h/600)\n",
    "\n",
    "###主要构建模块：  \n",
    "稠密块（dense block）： 定义了输入和输出是如何连结的。  \n",
    "过渡层（transition layer）：用来控制通道数，使之不过大。\n",
    "### 稠密块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "644303248ACD46C9886553FB41750A4F",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def conv_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(nn.BatchNorm2d(in_channels), \n",
    "                        nn.ReLU(),\n",
    "                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
    "    return blk\n",
    "\n",
    "class DenseBlock(nn.Module):\n",
    "    def __init__(self, num_convs, in_channels, out_channels):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        net = []\n",
    "        for i in range(num_convs):\n",
    "            in_c = in_channels + i * out_channels\n",
    "            net.append(conv_block(in_c, out_channels))\n",
    "        self.net = nn.ModuleList(net)\n",
    "        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数\n",
    "\n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "72F1294D9DDD454D9831BCA94764B9DE",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 23, 8, 8])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = DenseBlock(2, 3, 10)\n",
    "X = torch.rand(4, 3, 8, 8)\n",
    "Y = blk(X)\n",
    "Y.shape # torch.Size([4, 23, 8, 8])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "23FDE43573EA40389189D2D4659A59D9",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### 过渡层\n",
    "$1\\times1$卷积层：来减小通道数  \n",
    "步幅为2的平均池化层：减半高和宽"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "ED891AD5805040F6A93AD06B1E7C0FA8",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 4, 4])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def transition_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "            nn.BatchNorm2d(in_channels), \n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=1),\n",
    "            nn.AvgPool2d(kernel_size=2, stride=2))\n",
    "    return blk\n",
    "\n",
    "blk = transition_block(23, 10)\n",
    "blk(Y).shape # torch.Size([4, 10, 4, 4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AE6DADC26CE04D2C86F49867321D491F",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### DenseNet模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "E32E30B6657E4CB5B3027B663927AA83",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "        nn.BatchNorm2d(64), \n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "8464B9866CBF4EB280381A63A9F4557E",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "num_channels, growth_rate = 64, 32  # num_channels为当前的通道数，growth_rate是每经过一个卷积层增长的通道数（为32）\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
    "    net.add_module(\"DenseBlosk_%d\" % i, DB)\n",
    "    # 上一个稠密块的输出通道数\n",
    "    num_channels = DB.out_channels\n",
    "    # 在稠密块之间加入通道数减半的过渡层\n",
    "    if i != len(num_convs_in_dense_blocks) - 1:\n",
    "        net.add_module(\"transition_block_%d\" % i, transition_block(num_channels, num_channels // 2))\n",
    "        num_channels = num_channels // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "76C89C4331674B8C8F531F4157CACAB1",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "1  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "2  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "3  output shape:\t torch.Size([1, 64, 24, 24])\n",
      "DenseBlosk_0  output shape:\t torch.Size([1, 192, 24, 24])\n",
      "transition_block_0  output shape:\t torch.Size([1, 96, 12, 12])\n",
      "DenseBlosk_1  output shape:\t torch.Size([1, 224, 12, 12])\n",
      "transition_block_1  output shape:\t torch.Size([1, 112, 6, 6])\n",
      "DenseBlosk_2  output shape:\t torch.Size([1, 240, 6, 6])\n",
      "transition_block_2  output shape:\t torch.Size([1, 120, 3, 3])\n",
      "DenseBlosk_3  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "BN  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "relu  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 248, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "net.add_module(\"BN\", nn.BatchNorm2d(num_channels))\n",
    "net.add_module(\"relu\", nn.ReLU())\n",
    "net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)\n",
    "net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) \n",
    "\n",
    "X = torch.rand((1, 1, 96, 96))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "F70449A99F984956B8BEC295A53E0A99",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "#batch_size = 256\n",
    "batch_size=16\n",
    "# 如出现“out of memory”的报错信息，可减小batch_size或resize\n",
    "train_iter, test_iter =load_data_fashion_mnist(batch_size, resize=96)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
